# SPDX-License-Identifier: Apache-2.0

import numpy as np  # type: ignore

import onnx

from ..base import Base
from . import expect


class ReduceLogSum(Base):
    """Test-case exporters for the ``ReduceLogSum`` operator.

    Every case builds a node, draws a random float32 input of shape
    (3, 4, 5), computes the reference result with NumPy as the log of the
    summed input, and passes node, input and reference to ``expect`` under
    a fixed case name.
    """

    @staticmethod
    def export_nokeepdims() -> None:
        """Reduce over two axes with ``keepdims=0``, once per axis order."""
        # Cases run top to bottom; each one draws its own fresh random input.
        cases = (
            ((2, 1), "test_reduce_log_sum_desc_axes"),
            ((0, 1), "test_reduce_log_sum_asc_axes"),
        )
        for reduce_axes, case_name in cases:
            node = onnx.helper.make_node(
                "ReduceLogSum",
                inputs=["data"],
                outputs=["reduced"],
                axes=list(reduce_axes),
                keepdims=0,
            )
            data = np.random.ranf([3, 4, 5]).astype(np.float32)
            summed = data.sum(axis=reduce_axes, keepdims=False)
            expect(node, inputs=[data], outputs=[np.log(summed)], name=case_name)

    @staticmethod
    def export_keepdims() -> None:
        """Node built without ``axes`` or ``keepdims`` attributes.

        The reference sums over every axis and keeps the reduced dimensions.
        """
        node = onnx.helper.make_node(
            "ReduceLogSum",
            inputs=["data"],
            outputs=["reduced"],
        )
        data = np.random.ranf([3, 4, 5]).astype(np.float32)
        summed = data.sum(keepdims=True)
        expect(
            node,
            inputs=[data],
            outputs=[np.log(summed)],
            name="test_reduce_log_sum_default",
        )

    @staticmethod
    def export_negative_axes_keepdims() -> None:
        """Reduce over the single negative axis ``-2``, keeping dimensions."""
        node = onnx.helper.make_node(
            "ReduceLogSum",
            inputs=["data"],
            outputs=["reduced"],
            axes=[-2],
        )
        data = np.random.ranf([3, 4, 5]).astype(np.float32)
        summed = data.sum(axis=-2, keepdims=True)
        expect(
            node,
            inputs=[data],
            outputs=[np.log(summed)],
            name="test_reduce_log_sum_negative_axes",
        )
